import warnings

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.callbacks.monitor_callback import (
    MonitorCallback,  # For metric monitoring logic
)
from keras.src.utils.io_utils import print_msg
from keras.src.utils.module_utils import ocp

# Context and AsyncOptions are accessed through the lazy-loaded ocp module

# JAX monitoring compatibility: ensure record_scalar exists
# to prevent AttributeError in older JAX versions
try:
    import jax

    if not hasattr(jax.monitoring, "record_scalar"):
        jax.monitoring.record_scalar = lambda *args, **kwargs: None
except ImportError:
    pass


def _get_state_tree(model):
    """Get the complete model state as a nested tree structure."""
    # For JAX backend, preserve native arrays for performance
    # For other backends, convert to numpy arrays
    if backend.backend() == "jax":
        state_tree = model.get_state_tree()
        did_numpy_conversion = False
    else:
        state_tree = model.get_state_tree(value_format="numpy_array")
        did_numpy_conversion = True

    # Convert numpy scalar types to Python types for Orbax compatibility
    # Only needed when we did numpy conversion
    if did_numpy_conversion:

        def convert_scalars(obj):
            if isinstance(obj, np.ndarray) and obj.ndim == 0:
                # Convert 0-dimensional numpy arrays (scalars) to Python types
                return obj.item()
            elif isinstance(obj, np.generic):
                # Convert numpy scalar types (like np.float32) to Python types
                return obj.item()
            else:
                return obj

        return tree.map_structure(convert_scalars, state_tree)
    else:
        return state_tree


@keras_export("keras.callbacks.OrbaxCheckpoint")
class OrbaxCheckpoint(MonitorCallback):
    """Callback to save and load model state using Orbax with a similar API to
    ModelCheckpoint.

    This callback saves the model's weights and optimizer state asynchronously
    using Orbax, allowing training to continue without blocking for I/O.

    Example:

    ```python
    model.compile(loss=..., optimizer=..., metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_dir = '/tmp/ckpt'
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])

    # Alternatively, save checkpoints every N batches -
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq=100)  # Save every 100 batches

    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
    ```

    Args:
        directory: path to the directory where to save the checkpoints.
        monitor: The metric name to monitor (e.g., 'val_loss').
        verbose: Verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" based on the monitored quantity.
        save_weights_only: if `save_weights_only=True`, only the model's
            weights will be saved. Otherwise, the full model state
            (weights, non-trainable variables, optimizer state, and
            metrics state) will be saved. Defaults to False.
        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
        save_freq: `'epoch'` or integer. Frequency to save checkpoints.
        max_to_keep: Integer, maximum number of recent checkpoints to keep.
            If None, keeps all. Defaults to 1.
        save_on_background: Boolean, whether to save asynchronously in the
            background. Defaults to True.
        initial_value_threshold: Floating point initial "best" value for the
            monitor, used with `save_best_only`.
    """

    def __init__(
        self,
        directory,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        initial_value_threshold=None,
        max_to_keep=1,
        save_on_background=True,
    ):
        # Ensure orbax is available
        ocp.initialize()

        # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
        # logic
        super().__init__(monitor, mode, initial_value_threshold)

        self.directory = directory
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.max_to_keep = max_to_keep
        self.save_on_background = save_on_background
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self._current_epoch = 0  # Keep track of epoch
        self._total_batches_seen = 0  # Global batch counter for step tracking

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq are 'epoch' or integer values"
            )

        # --- Orbax Checkpointer Setup (V1 API) ---
        policies = []
        if max_to_keep is not None:
            policies.append(
                ocp.training.preservation_policies.LatestN(max_to_keep)
            )

        # Use AnyPreservationPolicy to combine them.
        preservation_policy = None
        if policies:
            preservation_policy = (
                ocp.training.preservation_policies.AnyPreservationPolicy(
                    policies
                )
            )

        # Create the V1 Checkpointer with direct parameter passing
        # Orbax will handle directory creation on all processes as needed
        self.checkpointer = ocp.training.Checkpointer(
            directory=directory,
            preservation_policy=preservation_policy,
        )

    def _should_save_on_batch(self, batch):
        """Check if we should save on this batch."""
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        self._total_batches_seen += add_batches

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _save_checkpoint(self, step, logs=None):
        """Save a checkpoint at the given step."""

        # --- Prepare Composite State (Backend-Agnostic) ---
        state_tree = _get_state_tree(self.model)

        # Save the nested state structures directly (preserving layer
        # names and structure)
        if self.save_weights_only:
            composite_state = {
                "trainable_variables": state_tree["trainable_variables"],
            }
            if "non_trainable_variables" in state_tree:
                composite_state["non_trainable_variables"] = state_tree[
                    "non_trainable_variables"
                ]
        else:
            composite_state = state_tree

        # --- Save Logic (V1 API) ---
        # All processes participate in distributed checkpointing
        # Checkpointer is configured to save unconditionally when
        # save_pytree is called
        if self.verbose > 0:
            print_msg(
                f"OrbaxCheckpoint: Triggering async save for step {step}..."
            )

        # Use a single with statement. If context_options is empty,
        # Context() uses defaults.
        with ocp.Context():
            if self.save_on_background:
                self.checkpointer.save_pytree_async(step, composite_state)
            else:
                self.checkpointer.save_pytree(step, composite_state)

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            # Handle save_best_only logic for batch-level saving
            should_save = True
            if self.save_best_only:
                current = logs.get(self.monitor) if logs else None
                if current is None:
                    warnings.warn(
                        f"Can save best model only with {self.monitor} "
                        f"available, skipping save at batch {batch}.",
                        stacklevel=2,
                    )
                    should_save = False
                elif not self._is_improvement(current, self.best):
                    should_save = False
                else:
                    # Update best value when there's improvement
                    self.best = current

            if should_save:
                # Use global batch count for Orbax save step
                step = self._total_batches_seen
                self._save_checkpoint(step=step, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self._current_epoch = epoch
        if self.monitor_op is None:
            self._set_monitor_op()  # From MonitorCallback

        # For save_freq="epoch", save at every epoch
        should_save = self.save_freq == "epoch"

        # Handle save_best_only logic
        if should_save and self.save_best_only:
            current = logs.get(self.monitor) if logs else None
            if current is None:
                warnings.warn(
                    f"Can save best model only with {self.monitor} available, "
                    f"skipping save at epoch {epoch}.",
                    stacklevel=2,
                )
                should_save = False
            elif not self._is_improvement(current, self.best):
                should_save = False
            else:
                # Update best value when there's improvement
                self.best = current

        if should_save:
            # Use epoch number as the step for Orbax save
            # Keras has already made the save decision - Checkpointer will
            # save unconditionally
            self._save_checkpoint(step=epoch, logs=logs)

    def on_train_end(self, logs=None):
        # Close the Checkpointer to ensure all pending saves complete
        try:
            self.checkpointer.close()
        except Exception:
            pass  # Ignore errors during cleanup

    def wait_until_finished(self):
        """Wait for any in-progress checkpoint operations to complete.
        This method blocks until all asynchronous checkpoint save operations
        have completed. It should be called before attempting to load
        checkpoints if there might be pending save operations.
        """
        # Wait for any async operations to complete
        if hasattr(self.checkpointer, "wait"):
            self.checkpointer.wait()
        else:
            # Fallback for older Orbax versions that don't have wait() method
            while self.checkpointer.is_saving_in_progress():
                import time

                time.sleep(0.1)
